#! /usr/bin/env python3

import argparse
import logging
import numpy as np

import pymetis

import urlearning.ScoreCache
import urlearning.TarjansAlgorithm
import urlearning.TopPPopsConstraint

import misc.utils as utils

# Default values for the command-line options below.
default_num_clusters = 2            # -c: clusters for the metis (parent grouping) partition
default_max_pattern_size = 21       # -s: max variables allowed in a component-grouping pattern
default_max_dynamic_pd_size = 0     # -d: max dynamic pattern size; 0 writes no dynamic entries

def adjacency_matrix2list(adjacency_matrix):
    """Convert a 0/1 adjacency matrix into a list of neighbor-index arrays.

    Row i of the matrix becomes entry i of the result: the array of column
    indices where row i equals 1 (the neighbors of vertex i).
    """
    neighbor_lists = []
    for row in adjacency_matrix:
        neighbor_lists.append(np.flatnonzero(row == 1))
    return neighbor_lists

def get_metis_partition(dgf, num_clusters):
    """Partition the variables of *dgf* into *num_clusters* groups using METIS.

    Returns a list with one entry per cluster; each entry is the array of
    variable indices METIS assigned to that cluster.
    """
    neighbors = adjacency_matrix2list(dgf.adjacencyMatrix)

    logging.debug("adjacency list: {}".format(neighbors))

    # part_graph returns (edge_cuts, membership); we only need the membership
    _, membership = pymetis.part_graph(num_clusters, adjacency=neighbors)
    membership = np.array(membership)

    # collect the variable indices belonging to each cluster
    return [np.where(membership == c)[0] for c in range(num_clusters)]

def get_neighbor_sccs(v_i, used_sccs, sccs, graph):
    """Return the indices of unused SCCs adjacent to the variable group *v_i*.

    *v_i* and each row of *sccs* are 0/1 indicator vectors over the
    variables; *used_sccs* is a boolean mask over the SCCs; *graph* is the
    adjacency matrix, which is assumed to already be symmetrized.
    """
    # rows of the adjacency matrix for the variables in the current group
    group_rows = graph[np.where(v_i == 1)[0]]

    neighbors = []
    for scc_index, indicator in enumerate(sccs):
        # skip SCCs already absorbed into some group
        if used_sccs[scc_index]:
            continue

        scc_vars = np.where(indicator == 1)[0]

        # this SCC neighbors the group iff any edge connects the two
        # variable sets, i.e., the cross submatrix has a nonzero entry
        if np.sum(group_rows[:, scc_vars]) > 0:
            neighbors.append(scc_index)

    return np.array(neighbors, dtype=int)

def get_component_grouping_partition(dgf, sccs, max_pattern_size):
    """Greedily group strongly connected components into patterns.

    Starting from the largest SCC, neighboring SCCs are absorbed while the
    combined group stays within *max_pattern_size*.  When no neighbor fits,
    the group is finalized and the largest remaining SCC seeds a new group.

    Parameters:
        dgf: graph object exposing an `adjacencyMatrix` attribute
            (assumed already symmetrized by the caller)
        sccs: 2-d 0/1 array; row i is the indicator vector of SCC i
        max_pattern_size: maximum number of variables allowed in one group

    Returns:
        A list of index arrays (one per group), or None when some SCC by
        itself exceeds max_pattern_size, in which case no valid grouping
        exists.
    """
    scc_sizes = np.sum(sccs, axis=1)
    largest_scc_index = np.argmax(scc_sizes)

    # the variables in the current (growing) group, as a 0/1 vector
    v_i = sccs[largest_scc_index]
    v_i_size = np.sum(v_i)

    # SCCs already assigned to some group
    used_sccs = np.zeros(len(sccs), dtype=bool)
    used_sccs[largest_scc_index] = True

    # finalized groups (kept as indicator vectors until the end)
    v = []

    while v_i_size <= max_pattern_size and np.sum(used_sccs) < len(sccs):
        msg = "Processing SCC: {}".format(v_i)
        logging.debug(msg)

        # find the unused SCCs which neighbor the current group
        possible_neighbors = get_neighbor_sccs(v_i, used_sccs, sccs,
            dgf.adjacencyMatrix)

        # consider larger neighbors first
        neighbor_sizes = scc_sizes[possible_neighbors]
        sorted_neighbors = possible_neighbors[np.argsort(neighbor_sizes)[::-1]]
        sorted_neighbor_sizes = np.sort(neighbor_sizes)[::-1]

        msg = "sorted_neighbor_sizes: {}".format(sorted_neighbor_sizes)
        logging.debug(msg)

        # absorb the largest neighbor which still fits, if any
        added = False
        for i, size in enumerate(sorted_neighbor_sizes):
            if size + v_i_size <= max_pattern_size:
                # mark the neighbor SCC as used ...
                neighbor_index = sorted_neighbors[i]
                used_sccs[neighbor_index] = True

                # ... and merge its variables into the current group
                v_i = np.sum([v_i, sccs[neighbor_index]], axis=0)
                v_i_size += size

                added = True
                break

        # if we absorbed a neighbor, try to grow the same group further
        if added:
            continue

        # otherwise, the group cannot grow any more; finalize it
        v.append(v_i)

        # and seed a new group with the largest remaining SCC
        remaining_sizes = scc_sizes[~used_sccs]
        remaining_indices = np.where(used_sccs == 0)[0]

        largest_scc_index = remaining_indices[np.argmax(remaining_sizes)]
        used_sccs[largest_scc_index] = True

        v_i = sccs[largest_scc_index]
        v_i_size = np.sum(v_i)

    # if the loop ended because every SCC was assigned, the current group
    # still has to be stored -- but only if it actually fits; otherwise the
    # last selected SCC (or the single largest one) was too large on its own
    # and no valid grouping exists
    if np.sum(used_sccs) == len(sccs) and v_i_size <= max_pattern_size:
        v.append(v_i)
    else:
        return None

    # finally, merge the trailing (small) groups even if they are not
    # neighbors; stop at i == 1 because at i == 0 the index i-1 would wrap
    # around to v[-1] (merging the first group into the last, or a lone
    # group into itself and then discarding it)
    for i in range(len(v) - 1, 0, -1):
        if sum(v[i]) + sum(v[i-1]) <= max_pattern_size:
            v[i-1] = np.sum([v[i], v[i-1]], axis=0)
            v.pop(i)
        else:
            break

    # convert the indicator vectors into arrays of variable indices
    for i, v_i in enumerate(v):
        v[i] = np.where(v_i == 1)[0]

    return v

def write_static_pd(pd, out):
    """Write one static pattern database line to *out*.

    The line starts with "s;", followed by the groups separated by ';',
    where each group is a comma-separated list of its variable indices.
    """
    groups = [",".join(str(var) for var in group) for group in pd]
    out.write("s;" + ";".join(groups) + "\n")


def main():
    """Read local scores, build pattern databases, and write them to disk."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="This script creates the component grouping pattern databases from "
        "(Fan and Yuan, AAAI 2015). If that does not work, it reverts to a simple unweighted "
        "version of the parent grouping.")
    parser.add_argument('pss', help="The pss file containing the local scores")
    parser.add_argument('out', help="The output pattern database file")
    parser.add_argument('-d', '--max-dynamic-pd-size', help="The maximum size of the patterns "
        "used in dynamic pattern databases", type=int, default=default_max_dynamic_pd_size)
    parser.add_argument('-c', '--num-clusters', help="The number of clusters to use, but "
        "only in case of the parent grouping.", type=int, default=default_num_clusters)
    parser.add_argument('-s', '--max-pattern-size', help="The maximum size allowed in "
        "creating the component grouping pattern database", type=int, default=default_max_pattern_size)

    utils.add_logging_options(parser)
    args = parser.parse_args()
    utils.update_logging(args)

    logging.info("Reading score file")

    # guess the score file format from its extension (defaults to pss)
    file_type = 'jkl' if args.pss.endswith('jkl') else 'pss'
    sc = urlearning.ScoreCache.ScoreCache(args.pss, file_type=file_type)

    # helpers for building the constrained graph and finding its SCCs
    tppc = urlearning.TopPPopsConstraint.TopPPopsConstraint()
    tarjans = urlearning.TarjansAlgorithm.TarjansAlgorithm()

    all_pattern_databases = []

    # collect pattern databases for increasing top-p constraints
    for p in range(1, sc.getMaxScoreCount()):
        logging.info("Using top-{} constraint".format(p))

        constrained_graph = tppc.createConstrainedGraph(sc, p=p)
        sccs = tarjans.getSCCs(constrained_graph)
        constrained_graph.symmetrizeMax()

        logging.info("Searching with metis")
        all_pattern_databases.append(
            get_metis_partition(constrained_graph, args.num_clusters))

        logging.info("Searching with component grouping")
        partition = get_component_grouping_partition(
            constrained_graph, sccs, args.max_pattern_size)

        if partition is None:
            logging.info("Could not find a valid component grouping")
            break

        all_pattern_databases.append(partition)

        if len(partition) == 1:
            # all variables fit in a single group, so there is no point
            # in trying further constraints; just quit
            logging.info("All of the variables fit in one group.")
            break

    logging.info("Writing {} static pattern databases to disk".format(
        len(all_pattern_databases)))

    with open(args.out, 'w') as out:
        for pd in all_pattern_databases:
            write_static_pd(pd, out)

        # request dynamic pattern databases up to the configured size
        for d in range(2, args.max_dynamic_pd_size + 1):
            out.write("d;{}\n".format(d))

# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
